package face6wap;

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import util.ShowUtils;
import util.ShowUtilsHanf;

import java.io.File;
import java.util.Map;

/**
 * LSGAN (least-squares GAN) demo — loss-modification experiment.
 * Based on: https://blog.csdn.net/weixin_44791964/article/details/103758751
 *
 * Training procedure (per iteration):
 * 1. Randomly pick batch_size real images.
 * 2. Randomly draw batch_size N-dimensional noise vectors and pass them through the
 *    Generator to produce batch_size fake images.
 * 3. Label the real images 1 and the fake images 0, then train the Discriminator on the
 *    combined real+fake set; the training loss is mean squared error.
 * 4. Train the Generator using the Discriminator's prediction on the fake images compared
 *    against 1 as the loss (if the Discriminator scores a fake as 1, the generated image
 *    looks "real"); this loss is also mean squared error.
 * 5. Verified working after changing the visualization approach.
 */
public class Gan_dis_ls_loss {

	/** Adam learning rate shared by the generator and discriminator layers. */
	static double lr = 0.0002;
	/** Path the discriminator graph is periodically saved to. */
	static String model = "F:/face/gan.zip";
	/** Mini-batch size; MNIST's 60000 training images divide evenly by it, so no partial batches. */
	static final int BATCH_SIZE = 30;
	/** Length of a flattened 28x28 MNIST image. */
	static final int IMG_PIXELS = 28 * 28;

	/**
	 * Builds the two graphs (discriminator, and generator + frozen discriminator copy)
	 * and runs the alternating LSGAN training loop, visualizing samples every 10
	 * iterations and saving the discriminator every 10000.
	 *
	 * @param args unused
	 * @throws Exception on model save failure or DL4J configuration errors
	 */
	public static void main(String[] args) throws Exception {
		// Let ND4J trigger workspace GC at most once every 100 seconds.
		Nd4j.getMemoryManager().setAutoGcWindow(100 * 1000);

		// Discriminator graph: the fake batch (input1) and real batch (input2) are
		// stacked into one mini-batch, then fed through the shared D tower.
		// NOTE(review): the stack order is "input2" (real) above "input1" (fake), so
		// labelD below assigns 0 to real rows and 1 to fake rows — inverted relative
		// to the usual GAN convention, but internally consistent with labelG.
		final GraphBuilder dis = addDiscriminatorLayers(new NeuralNetConfiguration.Builder()
				.updater(Adam.builder().learningRate(lr).beta1(0.5).build())
				.weightInit(WeightInit.XAVIER).graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1", "input2")
				.addVertex("stack", new StackVertex(), "input2", "input1"));

		// Generator graph: noise (input1) -> g1..g3 -> fake image, stacked with a
		// pass-through batch (input2), then the same D tower (kept frozen in trainG).
		final GraphBuilder gen = addDiscriminatorLayers(new NeuralNetConfiguration.Builder()
				.updater(Adam.builder().learningRate(lr).beta1(0.5).build())
				.weightInit(WeightInit.XAVIER).graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1", "input2")
				.addLayer("g1", new DenseLayer.Builder().nIn(IMG_PIXELS).nOut(128)
						.activation(new ActivationLReLU(0.2)).weightInit(WeightInit.XAVIER).build(), "input1")
				.addLayer("g2", new DenseLayer.Builder().nIn(128).nOut(512)
						.activation(new ActivationLReLU(0.2)).weightInit(WeightInit.XAVIER).build(), "g1")
				.addLayer("g3", new DenseLayer.Builder().nIn(512).nOut(IMG_PIXELS)
						.activation(new ActivationLReLU(0.2)).weightInit(WeightInit.XAVIER).build(), "g2")
				.addVertex("stack", new StackVertex(), "input2", "g3"));

		ComputationGraph net = new ComputationGraph(dis.build());
		ComputationGraph net1 = new ComputationGraph(gen.build());

		net.init();
		net1.init();

		System.out.println(net.summary());
		System.out.println(net1.summary());

		net.setListeners(new ScoreIterationListener(100));
		DataSetIterator train = new MnistOneDataSetIterator(BATCH_SIZE, true, 12345);

		// Labels for the stacked D batch: first BATCH_SIZE rows (real, per stack order
		// above) get 0, next BATCH_SIZE rows (fake) get 1.
		INDArray labelD = Nd4j.vstack(Nd4j.zeros(BATCH_SIZE, 1), Nd4j.ones(BATCH_SIZE, 1));

		// Generator target: the frozen D copy should score everything as 0 ("real"
		// under the inverted labelling above).
		INDArray labelG = Nd4j.zeros(2 * BATCH_SIZE, 1);

		for (int i = 1; i <= 100000; i++) {
			if (!train.hasNext()) {
				train.reset();
			}
			INDArray trueExp = train.next().getFeatures();

			// Generate a batch of fakes from N(0,1) noise.
			INDArray z = Nd4j.rand(new NormalDistribution(), new long[] { BATCH_SIZE, IMG_PIXELS });
			Map<String, INDArray> genOut = net1.feedForward(new INDArray[] { z, z }, false);
			INDArray fake = genOut.get("g3");

			// Train the discriminator on (fake, real); several D steps per G step.
			MultiDataSet dataSetD = new MultiDataSet(new INDArray[] { fake, trueExp },
					new INDArray[] { labelD });
			for (int m = 0; m < 10; m++) {
				trainD(net, dataSetD);
			}

			// Sync the freshly trained discriminator weights into the frozen
			// discriminator copy embedded in the generator graph.
			for (String layer : new String[] { "d1", "d2", "d3", "out" }) {
				net1.getLayer(layer).setParam("W", net.getLayer(layer).getParam("W"));
				net1.getLayer(layer).setParam("b", net.getLayer(layer).getParam("b"));
			}

			// Train the generator on fresh noise against labelG; the D layers are
			// given a zero learning rate inside trainG so only g1..g3 update.
			z = Nd4j.rand(new NormalDistribution(), new long[] { BATCH_SIZE, IMG_PIXELS });
			MultiDataSet dataSetG = new MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelG });
			trainG(net1, dataSetG);

			if (i % 10 == 0) {
				// Visualize two generated digits. The whole 2x784 batch is passed as a
				// single array element; the previous version allocated size(0) slots but
				// filled only the first, leaving null entries in the array.
				INDArray noise = Nd4j.rand(new NormalDistribution(), new long[] { 2, IMG_PIXELS });
				INDArray noise1 = Nd4j.rand(new NormalDistribution(), new long[] { 2, IMG_PIXELS });
				Map<String, INDArray> vis = net1.feedForward(new INDArray[] { noise, noise1 }, false);
				INDArray[] samples = new INDArray[] { vis.get("g3") };
				ShowUtilsHanf.visualize(samples, "拆分");
			}
			if (i % 10000 == 0) {
				// NOTE(review): only the discriminator graph is saved; the generator
				// (net1) — usually the artifact of interest — is never persisted.
				net.save(new File(model), true);
			}
		}
	}

	/**
	 * Appends the shared discriminator tower (d1 -> d2 -> d3 -> out, MSE loss, tanh
	 * output) on top of the "stack" vertex and marks "out" as the graph output.
	 * Used by both graph builders so the layer names and shapes stay identical,
	 * which the weight-sync loop in {@link #main} relies on.
	 *
	 * @param b a graph builder that already defines a "stack" vertex
	 * @return the same builder, for chaining
	 */
	private static GraphBuilder addDiscriminatorLayers(GraphBuilder b) {
		return b
				.addLayer("d1", new DenseLayer.Builder().nIn(IMG_PIXELS).nOut(256)
						.activation(new ActivationLReLU(0.2)).weightInit(WeightInit.XAVIER).build(), "stack")
				.addLayer("d2", new DenseLayer.Builder().nIn(256).nOut(128)
						.activation(new ActivationLReLU(0.2)).weightInit(WeightInit.XAVIER).build(), "d1")
				.addLayer("d3", new DenseLayer.Builder().nIn(128).nOut(128)
						.activation(new ActivationLReLU(0.2)).weightInit(WeightInit.XAVIER).build(), "d2")
				.addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(128).nOut(1)
						.activation(Activation.TANH).build(), "d3")
				.setOutputs("out");
	}

	/**
	 * Discriminator step D(x): restores the full learning rate on all D layers
	 * (trainG may have zeroed them on a shared graph) and fits one mini-batch.
	 *
	 * @param net     the discriminator graph
	 * @param dataSet stacked real+fake batch with labelD targets
	 */
	public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("d1", lr);
		net.setLearningRate("d2", lr);
		net.setLearningRate("d3", lr);
		net.setLearningRate("out", lr);
		net.fit(dataSet);
	}

	/**
	 * Generator step G(z): enables learning on g1..g3 and freezes the embedded
	 * discriminator copy (lr = 0) so only the generator weights update.
	 *
	 * @param net     the generator+discriminator graph
	 * @param dataSet noise batch with labelG targets
	 */
	public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("g1", lr);
		net.setLearningRate("g2", lr);
		net.setLearningRate("g3", lr);
		net.setLearningRate("d1", 0);
		net.setLearningRate("d2", 0);
		net.setLearningRate("d3", 0);
		net.setLearningRate("out", 0);
		net.fit(dataSet);
	}
}